from torch.utils.data import Dataset
import torch
import numpy as np
import random

class PFSPDataset(Dataset):
    """Map-style dataset of randomly generated problem instances.

    All instances are sampled once, at construction time, and stored in
    ``self.data`` as a tensor of shape ``(batch_size, n_jobs, n_mc)``.
    Indexing returns a single ``(n_jobs, n_mc)`` slice. These are presumably
    per-job, per-machine processing times (TODO confirm against the model).

    Args:
        batch_size: Number of instances to generate; this is also the
            dataset length.
        n_jobs: Size of the second tensor dimension.
        n_mc: Size of the third tensor dimension.
        mode: Sampling distribution name; see ``make_dataset``.
        seed: If not None, ``random``, NumPy and torch (CPU and CUDA) are
            seeded before sampling so the generated data is reproducible.
    """

    def __init__(self, batch_size, n_jobs, n_mc, mode, seed=None):
        super().__init__()
        self.n_jobs = n_jobs
        self.n_mc = n_mc
        self.batch_size = batch_size
        # Number of stored instances (first dimension of ``self.data``).
        self.n_samples = batch_size
        self.mode = mode
        self.seed = seed
        self.set_seed(seed)
        self.data = self.make_dataset()

    def set_seed(self, seed):
        """Seed Python, NumPy and torch RNGs when ``seed`` is not None."""
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    def __len__(self):
        """Return the number of stored instances."""
        # Derive the length from the data itself so it can never disagree
        # with what ``__getitem__`` can actually index.
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th instance, a ``(n_jobs, n_mc)`` tensor."""
        return self.data[idx]

    def make_dataset(self):
        """Sample the full ``(batch_size, n_jobs, n_mc)`` tensor.

        Supported ``self.mode`` values:
          * ``'normal'``: N(mean=6, std=6), with negatives clamped to 0.
          * ``'Gamma'``: chi-square with 1 degree of freedom, i.e.
            Gamma(shape=0.5, scale=2).
          * ``'uniform'``: integers drawn from [1, 99]. NumPy's ``high`` is
            exclusive.
          * anything else: continuous uniform on [5, 10).
        """
        shape = (self.batch_size, self.n_jobs, self.n_mc)
        if self.mode == 'normal':
            dataset = torch.Tensor(np.random.normal(6, 6, size=shape))
            # Processing times cannot be negative.
            dataset = torch.clamp(dataset, min=0)
        elif self.mode == 'Gamma':
            dataset = torch.Tensor(np.random.chisquare(1, size=shape))
        elif self.mode == 'uniform':
            dataset = torch.Tensor(np.random.randint(1, 100, shape))
        else:
            dataset = 5 + (10 - 5) * torch.rand(*shape)
        return dataset
    

def get_random_problems(batch_size, n_jobs, n_mc, mode):
    """Draw a fresh batch of random instances from the global RNG state.

    Returns a float tensor of shape ``(batch_size, n_jobs, n_mc)`` whose
    distribution depends on ``mode``:

      * ``'normal'``: N(mean=6, std=6), with negatives clamped to 0.
      * ``'Gamma'``: chi-square with 1 degree of freedom.
      * ``'uniform'``: integers in [1, 99], stored as floats.
      * anything else: continuous uniform on [5, 10).
    """
    shape = (batch_size, n_jobs, n_mc)

    if mode == 'normal':
        raw = torch.Tensor(np.random.normal(6, 6, size=shape))
        return torch.clamp(raw, min=0)

    if mode == 'Gamma':
        return torch.Tensor(np.random.chisquare(1, size=shape))

    if mode == 'uniform':
        return torch.Tensor(np.random.randint(1, 100, shape))

    # Fallback: continuous uniform samples on [5, 10).
    return 5 + (10 - 5) * torch.rand(*shape)


def get_random_eval_problems(batch_size, n_jobs, n_mc, mode, seed):
    """Draw a batch of random instances, optionally under a fixed seed.

    When ``seed`` is not None, torch (CPU and CUDA), Python's ``random`` and
    NumPy are all seeded first, so repeated calls with the same arguments
    return identical tensors.

    Returns a float tensor of shape ``(batch_size, n_jobs, n_mc)``. ``mode``
    selects the distribution:

      * ``'normal'``: N(mean=6, std=6), with negatives clamped to 0.
      * ``'Gamma'``: chi-square with 1 degree of freedom.
      * ``'uniform'``: integers in [1, 99], stored as floats.
      * anything else: continuous uniform on [5, 10).
    """
    if seed is not None:
        # Fix the seed of every RNG the samplers below might draw from.
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed_all(seed)

    dims = (batch_size, n_jobs, n_mc)

    if mode == 'normal':
        unclipped = torch.Tensor(np.random.normal(6, 6, size=dims))
        return torch.clamp(unclipped, min=0)
    if mode == 'Gamma':
        return torch.Tensor(np.random.chisquare(1, size=dims))
    if mode == 'uniform':
        return torch.Tensor(np.random.randint(1, 100, dims))
    return 5 + (10 - 5) * torch.rand(*dims)